HipKittens MXFP8 GEMM Support #566
| [](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) { | ||
| return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param)); | ||
| return MKN(std::get<0>(info.param)) + "x" + | ||
| std::to_string(std::get<1>(info.param)) + "x" + |
What is the point? They are only ever set to false.
|
|
||
| return torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device) | ||
| key = (device, ub, grouped_gemm) | ||
| ws = _workspace_cache.get(key) |
Why don't we rely on torch's memory caching?
I have made this change. I still need to do an E2E run to make sure performance isn't affected, but it should be OK given my understanding of torch.empty().
It doesn't seem to have changed.
| if (use_hipkittens) { | ||
| auto param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); | ||
|
|
||
| hipStream_t s = use_service_stream ? ss_ctl.stream : stream; |
Same as with is_mxfp8: there is no point in defining it for one branch only.
| @@ -743,12 +786,15 @@ MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16) | |||
|
|
|||
| INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite, | |||
If you end up having a separate prefix for MXFP8, it has to be used for this suite as well, for consistency.
ipanfilo left a comment
Some comments in rocm_gemm and test_cublas_gemm are still open
Creates an MXFP8 GEMM with HipKittens that outperforms hipBLASLt and offers additional epilogues such as BIAS and GELU AUX.
Requires a workspace sized relative to the model. The workspace is often larger than the one hipBLASLt needs, but comes with significant performance improvements. Only builds for gfx950, and requires M and N to be multiples of 256.
Adds the HipKittens header library as a submodule.
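The constraints in the description can be summarized as a small eligibility check. This is a sketch under assumptions, not the dispatch logic from the PR: it assumes the size requirement means that M and N must each be divisible by 256, and `can_use_hipkittens`, `TILE`, and the `arch` string argument are hypothetical names used only for illustration.

```python
TILE = 256  # assumed tile size along M and N, taken from the PR description


def can_use_hipkittens(m: int, n: int, arch: str) -> bool:
    """Return True if a GEMM of shape (m, n) could take the HipKittens path.

    Assumptions: the kernel is only available on gfx950, and both m and n
    must be positive multiples of TILE. Any other condition the real code
    checks (dtypes, layouts, K) is deliberately left out.
    """
    if arch != "gfx950":
        return False
    return m > 0 and n > 0 and m % TILE == 0 and n % TILE == 0
```

Shapes that fail the check would presumably fall back to the existing hipBLASLt path.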